{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ur8xi4C7S06n"
      },
      "outputs": [],
      "source": [
        "# Copyright 2025 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAPoU8Sm5E6e"
      },
      "source": [
        "# Automating Income Taxes with Gemini\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Open in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fuse-cases%2Fdocument-processing%2Ftax_automation.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Open in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/use-cases/document-processing/tax_automation.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/bigquery/import?url=https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\">\n",
        "      <img src=\"https://www.gstatic.com/images/branding/gcpiconscolors/bigquery/v1/32px.svg\" alt=\"BigQuery Studio logo\"><br> Open in BigQuery Studio\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/use-cases/document-processing/tax_automation.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| Author |\n",
        "| --- |\n",
        "| [Holt Skinner](https://github.com/holtskinner) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tvgnzT1CKxrO"
      },
      "source": [
        "## Overview\n",
        "\n",
        "Back in 2022, I wrote a [Google Cloud Blog post](https://cloud.google.com/blog/topics/developers-practitioners/automating-income-taxes-document-ai) about automating income tax preparation using [Document AI](https://cloud.google.com/document-ai/docs/overview).\n",
        "\n",
        "This demo used the [Lending processors](https://cloud.google.com/blog/products/ai-machine-learning/lending-docai-fast-tracks-the-home-loan-process) to extract data from W-2 and 1099 PDFs and calculate the total tax owed.\n",
        "\n",
        "In the world of Generative AI models like [Gemini](https://blog.google/technology/ai/google-gemini-ai/), it's possible to create the same document processing pipeline in a more efficient manner.\n",
        "\n",
        "In this notebook, we'll create a document understanding pipeline on some sample tax documents to:\n",
        "\n",
        "- Classify the type of document (W-2, 1099-DIV, 1099-INT, 1099-MISC, 1099-NEC)\n",
        "- Extract key fields based on the document type.\n",
        "\n",
        "These are the sample documents we will use:\n",
        "\n",
        "- [2020 Form 1099-DIV](https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/1099-DIV%20Parser/2020%20Form%201099-DIV%20-%20Anastasia%20Hodges.pdf)\n",
        "- [2020 Form 1099-INT](https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/1099-INT%20Parser/2020%20Form%201099-INT%20-%20Anastasia%20Hodges.pdf)\n",
        "- [2020 Form W-2](https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/W2Parser/2020/2020%20Form%20W-2%20-%20Anastasia%20Hodges.pdf)\n",
        "\n",
 Disclaimer">
        "> Disclaimer: This is **NOT** financial advice; it is for educational purposes only!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "61RBz8LLbxCR"
      },
      "source": [
        "## Get started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "No17Cw5hgx12"
      },
      "source": [
        "### Install Google Gen AI SDK for Python\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tFy3H3aPgx12"
      },
      "outputs": [],
      "source": [
        "%pip install --upgrade --quiet google-genai"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dmWOrTJ3gx13"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you're running this notebook on Google Colab, run the cell below to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NyKGtVQjgx13"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "if \"google.colab\" in sys.modules:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DF4l8DTdWgPY"
      },
      "source": [
        "### Set Google Cloud project information and create client\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Nqwi-5ufWp_B"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from google import genai\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type: \"string\"}\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = str(os.environ.get(\"GOOGLE_CLOUD_PROJECT\"))\n",
        "\n",
        "LOCATION = os.environ.get(\"GOOGLE_CLOUD_REGION\", \"us-central1\")\n",
        "\n",
        "client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C8aMIcn9mEWt"
      },
      "source": [
        "### Import libraries\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rrgrbhPmmEWt"
      },
      "outputs": [],
      "source": [
        "from enum import Enum\n",
        "\n",
        "from IPython.display import display\n",
        "from google.genai.types import GenerateContentConfig, Part\n",
        "import pandas as pd\n",
        "from pydantic import BaseModel, Field\n",
        "\n",
        "pd.set_option(\"display.max_colwidth\", None)\n",
        "PDF_MIME_TYPE = \"application/pdf\"\n",
        "JSON_MIME_TYPE = \"application/json\"\n",
        "ENUM_MIME_TYPE = \"text/x.enum\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e43229f3ad4f"
      },
      "source": [
        "### Load the Gemini 2.0 Flash model\n",
        "\n",
        "To learn more about all [Gemini models on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#gemini-models)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cf93d5f0ce00"
      },
      "outputs": [],
      "source": [
        "MODEL_ID = \"gemini-2.0-flash\"  # @param {type: \"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "53d3d02d82b0"
      },
      "source": [
        "Create a pandas DataFrame to contain the data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "43c76a9f1ea4"
      },
      "outputs": [],
      "source": [
        "tax_documents = pd.DataFrame(\n",
        "    {\n",
        "        \"file_uri\": [\n",
        "            \"https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/1099-DIV%20Parser/2020%20Form%201099-DIV%20-%20Anastasia%20Hodges.pdf\",\n",
        "            \"https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/1099-INT%20Parser/2020%20Form%201099-INT%20-%20Anastasia%20Hodges.pdf\",\n",
        "            \"https://storage.googleapis.com/cloud-samples-data/documentai/LendingDocAI/W2Parser/2020/2020%20Form%20W-2%20-%20Anastasia%20Hodges.pdf\",\n",
        "        ]\n",
        "    }\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EdvJRUWRNGHE"
      },
      "source": [
        "## Classify Documents\n",
        "\n",
        "First, we need to classify each of our documents.\n",
        "\n",
        "We will create an `Enum` class including each type of document."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "955ffce857e3"
      },
      "outputs": [],
      "source": [
        "class DocumentType(Enum):\n",
        "    W_2 = \"W-2\"\n",
        "    _1099_DIV = \"1099-DIV\"\n",
        "    _1099_INT = \"1099-INT\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "d3bd3714f764"
      },
      "source": [
        "Next, we will send each document to the Gemini model with a classification prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "649cb3dce660"
      },
      "outputs": [],
      "source": [
        "def classify_document(file_uri: str) -> Enum:\n",
        "    response = client.models.generate_content(\n",
        "        model=MODEL_ID,\n",
        "        contents=[\n",
        "            \"Classify the following document.\",\n",
        "            Part.from_uri(\n",
        "                file_uri=file_uri,\n",
        "                mime_type=PDF_MIME_TYPE,\n",
        "            ),\n",
        "        ],\n",
        "        config=GenerateContentConfig(\n",
        "            system_instruction=\"\"\"You are a document classification specialist. Given a document, your task is to find which category the document belongs to from the document categories provided in the schema.\"\"\",\n",
        "            temperature=0,\n",
        "            response_schema=DocumentType,\n",
        "            response_mime_type=ENUM_MIME_TYPE,\n",
        "        ),\n",
        "    )\n",
        "    return response.parsed\n",
        "\n",
        "\n",
        "tax_documents[\"classification\"] = tax_documents[\"file_uri\"].apply(classify_document)\n",
        "display(tax_documents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "843b123600c7"
      },
      "source": [
        "## Extract Data\n",
        "\n",
        "In order to extract the fields from each of these document types, we will need to create Pydantic classes containing the fields to extract for each type. Then we will create a mapping of the classification `Enum` to the Pydantic classes.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9c8d93b1ee61"
      },
      "source": [
        "### Create Pydantic classes\n",
        "\n",
        "> Note: These Pydantic models were created using Gemini with the following prompt:\n",
        "> \n",
        "> `Create a Pydantic class from BaseModel to contain values to extract from a [Document Type]`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fb3fd97752c4"
      },
      "outputs": [],
      "source": [
        "class FormW2(BaseModel):\n",
        "    \"\"\"\n",
        "    Pydantic class to represent data extracted from a Form W-2 (Wage and Tax Statement).\n",
        "    \"\"\"\n",
        "\n",
        "    employee_ssn: str = Field(..., description=\"Employee's Social Security Number\")\n",
        "    employer_ein: str = Field(\n",
        "        ..., description=\"Employer's Employer Identification Number\"\n",
        "    )\n",
        "    control_number: str | None = Field(\n",
        "        None, description=\"Employer's Control Number (Optional)\"\n",
        "    )\n",
        "    wages_tips_other_compensation: float = Field(\n",
        "        ..., description=\"Total Wages, tips, and other compensation\"\n",
        "    )\n",
        "    federal_income_tax_withheld: float = Field(\n",
        "        ..., description=\"Federal income tax withheld from wages\"\n",
        "    )\n",
        "    social_security_wages: float = Field(..., description=\"Social Security wages\")\n",
        "    social_security_tax_withheld: float = Field(\n",
        "        ..., description=\"Social Security tax withheld\"\n",
        "    )\n",
        "    medicare_wages_and_tips: float = Field(..., description=\"Medicare wages and tips\")\n",
        "    medicare_tax_withheld: float = Field(..., description=\"Medicare tax withheld\")\n",
        "    dependent_care_benefits: float | None = Field(\n",
        "        None, description=\"Dependent care benefits (Box 10)\"\n",
        "    )\n",
        "    nonqualified_plans: float | None = Field(\n",
        "        None, description=\"Nonqualified plans (Box 11)\"\n",
        "    )\n",
        "    box_12a_code: str | None = Field(None, description=\"Code for amount in Box 12a\")\n",
        "    box_12a_amount: float | None = Field(None, description=\"Amount for Code in Box 12a\")\n",
        "    box_12b_code: str | None = Field(None, description=\"Code for amount in Box 12b\")\n",
        "    box_12b_amount: float | None = Field(None, description=\"Amount for Code in Box 12b\")\n",
        "    box_12c_code: str | None = Field(None, description=\"Code for amount in Box 12c\")\n",
        "    box_12c_amount: float | None = Field(None, description=\"Amount for Code in Box 12c\")\n",
        "    box_12d_code: str | None = Field(None, description=\"Code for amount in Box 12d\")\n",
        "    box_12d_amount: float | None = Field(None, description=\"Amount for Code in Box 12d\")\n",
        "    statutory_employee: bool = Field(\n",
        "        False, description=\"Indicates if Statutory Employee\"\n",
        "    )\n",
        "    retirement_plan: bool = Field(False, description=\"Indicates if Retirement Plan\")\n",
        "    third_party_sick_pay: float | None = Field(\n",
        "        None, description=\"Third-party sick pay (Box 13)\"\n",
        "    )\n",
        "    other: str | None = Field(None, description=\"Other (Box 14)\")\n",
        "\n",
        "    employer_name: str = Field(..., description=\"Employer's Name\")\n",
        "    employer_address: str = Field(..., description=\"Employer's Address\")\n",
        "    employer_city: str = Field(..., description=\"Employer's City\")\n",
        "    employer_state: str = Field(..., description=\"Employer's State (abbreviation)\")\n",
        "    employer_zip: str = Field(..., description=\"Employer's Zip Code\")\n",
        "\n",
        "    employee_name: str = Field(..., description=\"Employee's Name\")\n",
        "    employee_address: str = Field(..., description=\"Employee's Address\")\n",
        "    employee_city: str = Field(..., description=\"Employee's City\")\n",
        "    employee_state: str = Field(..., description=\"Employee's State (abbreviation)\")\n",
        "    employee_zip: str = Field(..., description=\"Employee's Zip Code\")\n",
        "\n",
        "    state: str | None = Field(None, description=\"State (if applicable)\")\n",
        "    state_employer_id: str | None = Field(\n",
        "        None, description=\"State Employer ID (if applicable)\"\n",
        "    )\n",
        "    state_wages: float | None = Field(None, description=\"State Wages (if applicable)\")\n",
        "    state_income_tax: float | None = Field(\n",
        "        None, description=\"State Income Tax (if applicable)\"\n",
        "    )\n",
        "\n",
        "\n",
        "class Form1099DIV(BaseModel):\n",
        "    \"\"\"\n",
        "    Pydantic class representing data extracted from Form 1099-DIV (Dividends and Distributions).\n",
        "    \"\"\"\n",
        "\n",
        "    payer_name: str | None = Field(\n",
        "        None, description=\"Name of the payer (company distributing dividends).\"\n",
        "    )\n",
        "    payer_street_address: str | None = Field(\n",
        "        None, description=\"Payer's street address.\"\n",
        "    )\n",
        "    payer_city: str | None = Field(None, description=\"Payer's city.\")\n",
        "    payer_state: str | None = Field(None, description=\"Payer's state.\")\n",
        "    payer_zip: str | None = Field(None, description=\"Payer's zip code.\")\n",
        "    payer_telephone: str | None = Field(None, description=\"Payer's telephone number.\")\n",
        "    payer_tin: str | None = Field(\n",
        "        None,\n",
        "        description=\"Payer's Taxpayer Identification Number (TIN).\",\n",
        "        alias=\"payer_id\",\n",
        "    )\n",
        "\n",
        "    recipient_name: str | None = Field(None, description=\"Recipient's (your) name.\")\n",
        "    recipient_street_address: str | None = Field(\n",
        "        None, description=\"Recipient's street address.\"\n",
        "    )\n",
        "    recipient_city: str | None = Field(None, description=\"Recipient's city.\")\n",
        "    recipient_state: str | None = Field(None, description=\"Recipient's state.\")\n",
        "    recipient_zip: str | None = Field(None, description=\"Recipient's zip code.\")\n",
        "    recipient_identification_number: str | None = Field(\n",
        "        None,\n",
        "        description=\"Recipient's Taxpayer Identification Number (TIN) (usually your SSN).\",\n",
        "        alias=\"recipient_id\",\n",
        "    )\n",
        "    account_number: str | None = Field(\n",
        "        None, description=\"Recipient's account number (if applicable).\"\n",
        "    )\n",
        "\n",
        "    # Box Values\n",
        "    box_1a_total_ordinary_dividends: float | None = Field(\n",
        "        None, description=\"Box 1a: Total Ordinary Dividends.\"\n",
        "    )\n",
        "    box_1b_qualified_dividends: float | None = Field(\n",
        "        None, description=\"Box 1b: Qualified Dividends.\"\n",
        "    )\n",
        "    box_2a_total_capital_gain_distributions: float | None = Field(\n",
        "        None, description=\"Box 2a: Total Capital Gain Distributions.\"\n",
        "    )\n",
        "    box_2b_unrecaptured_section_1250_gain: float | None = Field(\n",
        "        None, description=\"Box 2b: Unrecaptured Section 1250 Gain.\"\n",
        "    )\n",
        "    box_2c_section_1202_gain: float | None = Field(\n",
        "        None, description=\"Box 2c: Section 1202 Gain.\"\n",
        "    )\n",
        "    box_2d_collectibles_28_percent_rate_gain: float | None = Field(\n",
        "        None, description=\"Box 2d: Collectibles (28%) Rate Gain\"\n",
        "    )\n",
        "    box_3_nondividend_distributions: float | None = Field(\n",
        "        None, description=\"Box 3: Nondividend Distributions.\"\n",
        "    )\n",
        "    box_4_federal_income_tax_withheld: float | None = Field(\n",
        "        None, description=\"Box 4: Federal Income Tax Withheld.\"\n",
        "    )\n",
        "    box_5_section_199A_dividends: float | None = Field(\n",
        "        None, description=\"Box 5: Section 199A Dividends.\"\n",
        "    )\n",
        "    # Note Box 6 is not needed as it only notes if its a section 199A distribution\n",
        "\n",
        "    foreign_tax_paid: float | None = Field(\n",
        "        None,\n",
        "        description=\"Foreign tax Paid (If any is marked by a boolean in the additional box section)\",\n",
        "    )\n",
        "\n",
        "    foreign_country: str | None = Field(None, description=\"Name of Foreign Country\")\n",
        "\n",
        "\n",
        "class Form1099INT(BaseModel):\n",
        "    \"\"\"\n",
        "    Pydantic class representing data extracted from a Form 1099-INT (Interest Income).\n",
        "    \"\"\"\n",
        "\n",
        "    payer_name: str = Field(..., description=\"Name of the payer (bank, institution)\")\n",
        "    payer_tin: str = Field(\n",
        "        ...,\n",
        "        description=\"Payer's Taxpayer Identification Number (TIN)\",\n",
        "        alias=\"payer_tax_id\",\n",
        "    )  # Added alias\n",
        "    recipient_name: str = Field(..., description=\"Recipient's Name\")\n",
        "    recipient_tin: str = Field(\n",
        "        ...,\n",
        "        description=\"Recipient's Taxpayer Identification Number (TIN)\",\n",
        "        alias=\"recipient_tax_id\",\n",
        "    )  # Added alias\n",
        "    recipient_address: str = Field(..., description=\"Recipient's Address\")\n",
        "    recipient_city_state_zip: str = Field(\n",
        "        ..., description=\"Recipient's City, State, and Zip Code\"\n",
        "    )\n",
        "\n",
        "    box_1_interest_income: float = Field(..., description=\"Box 1: Interest Income\")\n",
        "    box_2_early_withdrawal_penalty: float | None = Field(\n",
        "        None, description=\"Box 2: Early Withdrawal Penalty\"\n",
        "    )\n",
        "    box_3_interest_us_savings_bonds_treas_obligations: float | None = Field(\n",
        "        None,\n",
        "        description=\"Box 3: Interest on U.S. Savings Bonds and Treasury Obligations\",\n",
        "    )\n",
        "    box_4_federal_income_tax_withheld: float | None = Field(\n",
        "        None, description=\"Box 4: Federal Income Tax Withheld\"\n",
        "    )\n",
        "    box_5_investment_expenses: float | None = Field(\n",
        "        None, description=\"Box 5: Investment Expenses\"\n",
        "    )\n",
        "    box_6_foreign_tax_paid: float | None = Field(\n",
        "        None, description=\"Box 6: Foreign Tax Paid\"\n",
        "    )\n",
        "    box_7_foreign_country_or_us_possession: str | None = Field(\n",
        "        None, description=\"Box 7: Foreign Country or U.S. Possession\"\n",
        "    )\n",
        "\n",
        "    account_number: str | None = Field(\n",
        "        None, description=\"Account Number (may be truncated)\"\n",
        "    )\n",
        "    form_year: int | None = Field(None, description=\"Year the form applies to\")\n",
        "    payer_street_address: str | None = Field(None, description=\"Payer's Street Address\")\n",
        "    payer_city_state_zip: str | None = Field(\n",
        "        None, description=\"Payer's City, State, and Zip Code\"\n",
        "    )\n",
        "\n",
        "\n",
        "document_mapping: dict[DocumentType, BaseModel] = {\n",
        "    DocumentType.W_2: FormW2,\n",
        "    DocumentType._1099_DIV: Form1099DIV,\n",
        "    DocumentType._1099_INT: Form1099INT,\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7ea36f19d9f7"
      },
      "source": [
        "### Define the Gemini prompt\n",
        "\n",
        "Here's the prompt we'll use with Gemini to extract the information we need."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "281b1ff1b0b1"
      },
      "outputs": [],
      "source": [
        "def extract_document(row: pd.Series) -> dict:\n",
        "    response = client.models.generate_content(\n",
        "        model=MODEL_ID,\n",
        "        contents=[\n",
        "            f\"Extract from the following {row['classification'].value} document.\",\n",
        "            Part.from_uri(\n",
        "                file_uri=row[\"file_uri\"],\n",
        "                mime_type=PDF_MIME_TYPE,\n",
        "            ),\n",
        "        ],\n",
        "        config=GenerateContentConfig(\n",
        "            system_instruction=\"\"\"You are an expert in United States Tax Forms. Given a document, extract fields for income tax filing.\"\"\",\n",
        "            temperature=0,\n",
        "            response_schema=document_mapping.get(row[\"classification\"]),\n",
        "            response_mime_type=JSON_MIME_TYPE,\n",
        "        ),\n",
        "    )\n",
        "    print(row[\"file_uri\"])\n",
        "    print(response.parsed)\n",
        "    return response.parsed.model_dump()\n",
        "\n",
        "\n",
        "tax_documents[\"extraction\"] = tax_documents.apply(extract_document, axis=1)\n",
        "\n",
        "# Normalize and flatten the extracted fields\n",
        "extracted_df = pd.json_normalize(tax_documents[\"extraction\"])\n",
        "\n",
        "# Merge the extracted fields back into the original dataframe\n",
        "tax_documents = tax_documents.drop(columns=[\"extraction\"]).join(extracted_df)\n",
        "\n",
        "display(tax_documents)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8f9bd6bce1c9"
      },
      "source": [
        "Now, we'll load the data to a CSV for further processing and tax calculation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "89808e81a5c9"
      },
      "outputs": [],
      "source": [
        "tax_documents.to_csv(\"tax_data.csv\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "tax_automation.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
